Pytorch Tensor形状变换

之前一直知道Pytorch、numpy之类有好几种变换维度的方法,但是之前了解的不深,今天总结一下。

Reshape

Pytorch在reshape的时候,是按照行存储的方式进行的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch

In [16]: x = torch.randn(2, 3, 4)

In [17]: x
Out[17]:
tensor([[[ 0.2921, -0.1806, -1.0838, -0.6770],
[-0.7797, 0.2614, -0.5380, 1.8941],
[-0.7261, 0.8209, 0.0286, 0.0997]],

[[ 0.5549, 0.9036, 0.8790, -0.5776],
[ 0.3745, 0.6963, -0.3445, -0.0022],
[-0.5991, 1.4639, -0.5396, -0.1702]]])

In [19]: torch.reshape(x, (-1, 2))
Out[19]:
tensor([[ 0.2921, -0.1806],
[-1.0838, -0.6770],
[-0.7797, 0.2614],
[-0.5380, 1.8941],
[-0.7261, 0.8209],
[ 0.0286, 0.0997],
[ 0.5549, 0.9036],
[ 0.8790, -0.5776],
[ 0.3745, 0.6963],
[-0.3445, -0.0022],
[-0.5991, 1.4639],
[-0.5396, -0.1702]])

permute

permute会将tensor维度进行调整。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
In [24]: y = x.permute(1, 2, 0)

In [25]: y
Out[25]:
tensor([[[ 0.2921, 0.5549],
[-0.1806, 0.9036],
[-1.0838, 0.8790],
[-0.6770, -0.5776]],

[[-0.7797, 0.3745],
[ 0.2614, 0.6963],
[-0.5380, -0.3445],
[ 1.8941, -0.0022]],

[[-0.7261, -0.5991],
[ 0.8209, 1.4639],
[ 0.0286, -0.5396],
[ 0.0997, -0.1702]]])

In [26]: y.size()
Out[26]: torch.Size([3, 4, 2])

In [29]: x[:, 1, 1]
Out[29]: tensor([0.2614, 0.6963])

In [31]: y[1, 1, :]
Out[31]: tensor([0.2614, 0.6963])
------ 本文结束------
坚持原创技术分享,您的支持将鼓励我继续创作!

欢迎关注我的其它发布渠道